# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
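"""Tests for torchrl's LLM/RLHF data utilities.

Covers tokenized dataset loading and caching (``TokenizedDatasetLoader``),
tokenizer wrappers (``TensorDictTokenizer``, ``PromptTensorDictTokenizer``),
dataloader construction (``get_dataloader``), the GPT-2 reward model and
rollout generation from a language model (``RolloutFromModel``).
"""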
from __future__ import annotations

import argparse
import os
import zipfile
from copy import deepcopy
from pathlib import Path

import numpy as np
import pytest
import torch
import torch.nn.functional as F

from tensordict import (
    is_tensor_collection,
    MemoryMappedTensor,
    TensorDict,
    TensorDictBase,
)
from tensordict.nn import TensorDictModule
from torchrl.data.llm import TensorDictTokenizer
from torchrl.data.llm.dataset import (
    _has_datasets,
    _has_transformers,
    get_dataloader,
    TokenizedDatasetLoader,
)
from torchrl.data.llm.prompt import PromptData, PromptTensorDictTokenizer
from torchrl.data.llm.reward import PairwiseDataset, pre_tokenization_hook
from torchrl.data.llm.utils import RolloutFromModel
from torchrl.modules.models.llm import GPT2RewardModel

if os.getenv("PYTORCH_TEST_FBCODE"):
    from pytorch.rl.test._utils_internal import get_default_devices
else:
    from _utils_internal import get_default_devices

HERE = Path(__file__).parent


@pytest.fixture
def tmpdir1(tmp_path_factory):
    yield tmp_path_factory.mktemp("tmpdir1")


@pytest.fixture(scope="session")
def minidata_dir_comparison(tmp_path_factory):
    dest = tmp_path_factory.mktemp("comparisons")
    dataset_path = HERE / "assets" / "openai_summarize_comparisons.zip"
    with zipfile.ZipFile(dataset_path, "r") as zip_ref:
        zip_ref.extractall(dest)
    yield dest / dataset_path.stem


@pytest.fixture(scope="session")
def minidata_dir_tldr(tmp_path_factory):
    dest = tmp_path_factory.mktemp("tldr")
    dataset_path = HERE / "assets" / "openai_summarize_tldr.zip"
    with zipfile.ZipFile(dataset_path, "r") as zip_ref:
        zip_ref.extractall(dest)
    yield dest / dataset_path.stem


@pytest.fixture(scope="session")
def tldr_batch_dir(tmp_path_factory):
    dest = tmp_path_factory.mktemp("tldr_batch")
    dataset_path = HERE / "assets" / "tldr_batch.zip"
    with zipfile.ZipFile(dataset_path, "r") as zip_ref:
        zip_ref.extractall(dest)
    yield dest / dataset_path.stem

    # On teardown, print the extracted directory layout to help debug failures.
    from torchrl._utils import print_directory_tree

    print_directory_tree(dest)


@pytest.mark.skipif(
    not (_has_transformers and _has_datasets), reason="missing dependencies"
)
@pytest.mark.parametrize("max_length", [12, 550])
@pytest.mark.parametrize(
    "dataset,make_process_fn,pre_tokenization_hook",
    [
        (
            "comp",
            TensorDictTokenizer,
            pre_tokenization_hook,
        ),
        (
            "tldr",
            PromptTensorDictTokenizer,
            None,
        ),
    ],
)
def test_create_or_load_dataset(
    tmpdir1,
    minidata_dir_tldr,
    minidata_dir_comparison,
    max_length,
    dataset,
    make_process_fn,
    pre_tokenization_hook,
    mocker,
):
    # Check dataset caching: the first pass tokenizes and memory-maps the dataset,
    # the second pass should reload it from disk through TensorDict.load_memmap.
    lmemmap_save = deepcopy(TensorDict.load_memmap)
    mocked_load_memmap = mocker.patch("tensordict.TensorDict.load_memmap")
    mocked_load_memmap.side_effect = lmemmap_save
    if dataset == "tldr":
        dataset = minidata_dir_tldr
    elif dataset == "comp":
        dataset = minidata_dir_comparison
    else:
        raise NotImplementedError

    for i in range(2):
        data = TokenizedDatasetLoader(
            split="train",
            max_length=max_length,
            dataset_name=dataset,
            tokenizer_fn=make_process_fn,
            pre_tokenization_hook=pre_tokenization_hook,
            from_disk=True,
            root_dir=tmpdir1,
        ).load()
        if i == 0:
            mocked_load_memmap.assert_not_called()
        else:
            mocked_load_memmap.assert_called()

        assert isinstance(data, TensorDict)
        # assert "train" in data.keys(), data
        # assert ("train", str(max_length)) in data.keys(True), data
        for val in data.values(True, True):
            if val.ndim > 1:
                assert val.shape[1] == max_length


@pytest.mark.skipif(
    not (_has_transformers and _has_datasets), reason="missing dependencies"
)
@pytest.mark.parametrize("max_length", [12, 550])
@pytest.mark.parametrize(
    "dataset,make_process_fn,pre_tokenization_hook,split",
    [
        ("comp", TensorDictTokenizer, pre_tokenization_hook, "train"),
        ("comp", TensorDictTokenizer, pre_tokenization_hook, "valid1"),
        ("tldr", PromptTensorDictTokenizer, None, "train"),
        ("tldr", PromptTensorDictTokenizer, None, "valid"),
    ],
)
def test_preproc_data(
    tmpdir1,
    max_length,
    dataset,
    make_process_fn,
    pre_tokenization_hook,
    minidata_dir_tldr,
    minidata_dir_comparison,
    split,
):
    import datasets

    if dataset == "tldr":
        dataset = minidata_dir_tldr
    elif dataset == "comp":
        dataset = minidata_dir_comparison
    else:
        raise NotImplementedError
    loader = TokenizedDatasetLoader(
        split=split,
        max_length=max_length,
        dataset_name=dataset,
        tokenizer_fn=make_process_fn,
        pre_tokenization_hook=pre_tokenization_hook,
        from_disk=True,
        root_dir=tmpdir1,
        valid_size=20,
    )
    dataset = loader._load_dataset()
    assert isinstance(dataset, datasets.Dataset)
    dataset = loader._tokenize(dataset)
    assert isinstance(dataset, TensorDictBase)


@pytest.mark.skipif(
    not (_has_transformers and _has_datasets), reason="missing dependencies"
)
@pytest.mark.parametrize("suffix", ["c", ("c", "d")])
def test_dataset_to_tensordict(tmpdir, suffix):
    import datasets

    dataset = datasets.Dataset.from_dict({"a": np.zeros((10,)), "b": np.ones((10,))})
    td = TokenizedDatasetLoader.dataset_to_tensordict(dataset, tmpdir, prefix=suffix)
    if suffix == "c":
        assert ("c", "a") in td.keys(True)
        assert ("c", "b") in td.keys(True)
    else:
        assert ("c", "d", "a") in td.keys(True)
        assert ("c", "d", "b") in td.keys(True)
    assert isinstance(td.get((suffix, "a")), MemoryMappedTensor)
    assert isinstance(td.get((suffix, "b")), MemoryMappedTensor)


@pytest.mark.skipif(
    not (_has_transformers and _has_datasets), reason="missing dependencies"
)
@pytest.mark.parametrize("batch_size", [5, 6])
@pytest.mark.parametrize("block_size", [15, 50])
@pytest.mark.parametrize(
    "tensorclass_type,dataset",
    [
        (PromptData, "tldr"),
        (PairwiseDataset, "comp"),
    ],
)
@pytest.mark.parametrize("device", get_default_devices())
@pytest.mark.parametrize("split", ["train"])
@pytest.mark.parametrize("infinite", [True, False])
def test_get_dataloader(
    tmpdir1,
    tensorclass_type,
    batch_size,
    block_size,
    device,
    dataset,
    split,
    infinite,
    minidata_dir_tldr,
    minidata_dir_comparison,
):
    if dataset == "tldr":
        dataset = minidata_dir_tldr
    elif dataset == "comp":
        dataset = minidata_dir_comparison
    else:
        raise NotImplementedError
    dl = get_dataloader(
        batch_size,
        block_size,
        tensorclass_type,
        device,
        dataset_name=dataset,
        infinite=infinite,
        prefetch=0,
        split=split,
        root_dir=tmpdir1,
        from_disk=True,
    )
    for data in dl:  # noqa: B007
        break
    assert data.shape[0] == batch_size
    for value in data.values():
        if value.ndim > 1:
            assert value.shape[1] == block_size
    assert data.device == device
    # Whether infinite or not, get_dataloader should return an iterable, not a
    # tensordict.
    assert not is_tensor_collection(dl)


@pytest.mark.skipif(
    not (_has_transformers and _has_datasets), reason="missing dependencies"
)
class TestTokenizers:
    @pytest.mark.parametrize("max_length", [10, 15])
    @pytest.mark.parametrize("key", ["text", "other"])
    @pytest.mark.parametrize("padding", ["max_length"])
    @pytest.mark.parametrize("truncation", [True, False])
    @pytest.mark.parametrize("return_tensordict", [True, False])
    @pytest.mark.parametrize("device", get_default_devices())
    def test_tensordict_tokenizer(
        self, max_length, key, padding, truncation, return_tensordict, device
    ):
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = "-pad-"
        process = TensorDictTokenizer(
            tokenizer,
            max_length=max_length,
            key=key,
            padding=padding,
            truncation=truncation,
            return_tensordict=return_tensordict,
            device=device,
        )
        example = {
            key: [
                "Knock, knock.",
                "Who's there?",
                "Lettuce.",
                "Lettuce who?",
                "Lettuce in, it's cold out here!",
            ]
        }
        out = process(example)
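        # With truncation disabled and max_length=10, some sentences exceed max_length,
        # so the stacked output is ragged and its trailing dim is reported as -1.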
        if not truncation and return_tensordict and max_length == 10:
            assert out.get("input_ids").shape[-1] == -1
        elif return_tensordict:
            assert out.get("input_ids").shape[-1] == max_length
        else:
            obj = out.get("input_ids")
            while not isinstance(obj[-1], int):
                obj = obj[-1]
            if not truncation:
                assert len(obj) >= max_length
            else:
                assert len(obj) == max_length

    @pytest.mark.parametrize("max_length", [10, 15])
    @pytest.mark.parametrize("key", ["text", "other"])
    @pytest.mark.parametrize("padding", ["max_length"])
    @pytest.mark.parametrize("truncation", [True, False])
    @pytest.mark.parametrize("return_tensordict", [True, False])
    @pytest.mark.parametrize("device", get_default_devices())
    def test_prompt_tensordict_tokenizer(
        self, max_length, key, padding, truncation, return_tensordict, device
    ):
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained("gpt2")
        tokenizer.pad_token = "-pad-"
        process = PromptTensorDictTokenizer(
            tokenizer,
            max_length=max_length,
            key=key,
            padding=padding,
            truncation=truncation,
            return_tensordict=return_tensordict,
            device=device,
        )
        example = {
            key: [
                "Knock, knock.",
                "Who's there?",
                "Lettuce.",
                "Lettuce who?",
                "Lettuce in, it's cold out here!",
            ],
            "label": ["right", "wrong", "right", "wrong", "right"],
        }
        out = process(example)
        if not truncation and return_tensordict and max_length == 10:
            assert out.get("input_ids").shape[-1] == -1
        elif return_tensordict:
            assert out.get("input_ids").shape[-1] == max_length
        else:
            obj = out.get("input_ids")
            while not isinstance(obj[-1], int):
                obj = obj[-1]
            if not truncation:
                assert len(obj) >= max_length
            else:
                assert len(obj) == max_length


@pytest.mark.skipif(
    not (_has_transformers and _has_datasets), reason="missing dependencies"
)
@pytest.mark.parametrize("batch_size", [5, 6])
@pytest.mark.parametrize("block_size", [550, 560])
@pytest.mark.parametrize("device", get_default_devices())
def test_reward_model(tmpdir1, minidata_dir_comparison, batch_size, block_size, device):
    dl = get_dataloader(
        batch_size,
        block_size,
        PairwiseDataset,
        device,
        dataset_name=minidata_dir_comparison,
        infinite=True,
        prefetch=0,
        split="train",
        root_dir=tmpdir1,
        from_disk=True,
    )

    reward_model = GPT2RewardModel().to(device)

    batch = next(dl)
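    # The reward model returns a per-token reward tensor and a scalar end-of-sequence
    # score for each element of the batch.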
    chosen_rewards, chosen_end_scores = reward_model(
        input_ids=batch.chosen_data.input_ids,
        attention_mask=batch.chosen_data.attention_mask,
    )
    rejected_rewards, _ = reward_model(
        input_ids=batch.rejected_data.input_ids,
        attention_mask=batch.rejected_data.attention_mask,
    )

    assert chosen_rewards.shape == torch.Size([batch_size, block_size])
    assert chosen_end_scores.shape == torch.Size([batch_size])

    batch.chosen_data.rewards = chosen_rewards
    batch.rejected_data.rewards = rejected_rewards

    loss = reward_model.compute_reward_loss(batch.chosen_data, batch.rejected_data)
    assert loss.shape == torch.Size([])


@pytest.mark.skipif(
    not (_has_transformers and _has_datasets), reason="missing dependencies"
)
class TestRollout:
    kl_coef = 0.1

    @staticmethod
    def init_transformer(device="cpu", as_tensordictmodule=True, inference=False):
        from transformers import GPT2Config, GPT2LMHeadModel

        model = GPT2LMHeadModel(GPT2Config())
        model.to(device)

        if as_tensordictmodule:
            model = TensorDictModule(
                model,
                in_keys={
                    "input_ids": "input_ids",
                    "attention_mask": "attention_mask",
                    "labels": "labels",
                },
                out_keys=["logits"] if inference else ["loss", "logits"],
            )
        return model

    @staticmethod
    def init_reward_model(device=None):
        model = GPT2RewardModel()
        model.to(device)

        model = TensorDictModule(
            model,
            in_keys=["input_ids", "attention_mask"],
            out_keys=["rewards", "end_scores"],
        )
        return model

    @staticmethod
    def _get_dummy_batch(batch_dir):
        return TensorDict.load_memmap(batch_dir)

    @property
    def _model(self):
        return self.init_transformer(
            as_tensordictmodule=False,
            inference=True,
        )

    @property
    def _ref_model(self):
        return self.init_transformer(
            as_tensordictmodule=False,
            inference=True,
        )

    @property
    def _reward_model(self):
        return self.init_reward_model()

    def _get_rollout_model(self, max_new_tokens=10):
        return RolloutFromModel(
            model=self._model,
            ref_model=self._ref_model,
            reward_model=self._reward_model,
            max_new_tokens=max_new_tokens,
        )

    def test_padded_right_to_left(self):
        x = torch.arange(12).view(3, 4)
        x[0, -2:] = 100
        x[1, -1:] = 100
        x[2, -3:] = 100
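        # 100 stands in for the pad/eos token: the right-padded rows above should be
        # converted to left padding with token order preserved.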
        y = RolloutFromModel._padded_right_to_left(x, eos_token_id=100)
        y_test = torch.tensor([[100, 100, 0, 1], [100, 4, 5, 6], [100, 100, 100, 8]])
        assert (y == y_test).all()

    @pytest.mark.parametrize("right_padded", [False, True])
    @pytest.mark.parametrize("sequence_length", [None, 5])
    def test_padded_left_to_right(self, right_padded, sequence_length):
        x = torch.arange(12).view(3, 4)
        x[0, :2] = 100
        x[1, :1] = 100
        x[2, :3] = 100
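        # 100 stands in for the pad/eos token: the left-padded rows above should come
        # back right-padded, optionally padded further to sequence_length.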
        if right_padded:
            x[..., -1] = 100
        y = RolloutFromModel._padded_left_to_right(
            x, eos_token_id=100, sequence_length=sequence_length
        )
        if not right_padded:
            y_test = torch.tensor(
                [[2, 3, 100, 100], [5, 6, 7, 100], [11, 100, 100, 100]]
            )
        else:
            y_test = torch.tensor(
                [[2, 100, 100, 100], [5, 6, 100, 100], [100, 100, 100, 100]]
            )
        if sequence_length:
            y_test = F.pad(y_test, (0, 1), value=100)

        assert (y == y_test).all()

    @pytest.mark.parametrize("batch_size", [2])
    @pytest.mark.parametrize("max_new_tokens", [10])
    @pytest.mark.parametrize("use_max", [True, False])
    def test_get_scores(self, batch_size, max_new_tokens, use_max):
        scores = torch.arange(batch_size * max_new_tokens**2, dtype=torch.float).view(
            batch_size, max_new_tokens, max_new_tokens
        )
        gen_tokens = torch.arange(max_new_tokens).expand(1, max_new_tokens)
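        # scores holds one [step, vocab]-like matrix per batch element. Without use_max,
        # _get_scores should return the log-softmax score of each generated token (the
        # diagonal, since gen_tokens is arange); with use_max, the score of the
        # highest-scoring token (the last column, since scores increase along the
        # vocab dim).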
        scores_comp = self._get_rollout_model(
            max_new_tokens=max_new_tokens
        )._get_scores(scores.unbind(1), generated_tokens=gen_tokens, use_max=use_max)
        if not use_max:
            assert (
                scores_comp.squeeze()
                == torch.diagonal(scores.log_softmax(-1), 0, -2, -1).squeeze()
            ).all()
        else:
            assert (
                scores_comp.squeeze() == scores.log_softmax(-1)[..., -1].squeeze()
            ).all()

    def test_generate(self, tldr_batch_dir, max_new_tokens=10):
        model = self._get_rollout_model(max_new_tokens)
        batch = self._get_dummy_batch(tldr_batch_dir)
        generated, log_probs, log_ratio = model.generate(batch)
        batch_size = batch.shape[0]

        assert generated.shape == torch.Size(
            [batch_size, batch.input_ids.shape[1] + max_new_tokens]
        )
        assert log_probs.shape == torch.Size([batch_size, max_new_tokens, 1])
        assert (log_probs <= 0).all().item()
        assert log_ratio.shape == torch.Size([batch_size, max_new_tokens])

    def test_rollout_from_data(self, tldr_batch_dir, max_new_tokens=10):
        model = self._get_rollout_model(max_new_tokens)
        batch = self._get_dummy_batch(tldr_batch_dir)
        td = model.rollout_from_data(batch)
        batch_size = batch.shape[0]

        expected_keys = {
            ("next", "attention_mask"),
            ("next", "done"),
            ("next", "terminated"),
            ("next", "input_ids"),
            ("next", "reward"),
            "action",
            "attention_mask",
            "input_ids",
            "sample_log_prob",
        }
        keys = set(td.keys(True, True))
        assert all(key in keys for key in expected_keys)
        assert td.batch_size == torch.Size([batch_size, max_new_tokens])


if __name__ == "__main__":
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
